library(tidyverse)
library(ragg)

# Root directory of the ML prediction outputs; the score tables used below
# are read from its "Scores/" sub-directory.
prediction_path <- "./Data/ML Prediction/"

# File names found in the scores directory, restricted to those whose name
# contains "ASM".
scores_list <- list.files(paste0(prediction_path, "Scores/"))
scores_list <- scores_list[grepl("ASM", scores_list)]

# Fail early with a clear message: without any score file the reshaping and
# renaming steps further down would stop on missing columns instead.
if (length(scores_list) == 0) {
  stop(
    "No score files containing \"ASM\" found in ",
    paste0(prediction_path, "Scores/"),
    call. = FALSE
  )
}

# Read every selected score file once and reuse the tables for all three
# summaries below, instead of re-reading each file for every summary.
score_tables <- lapply(scores_list, function(df_file) {
  read_csv(paste0(prediction_path, "Scores/", df_file))
})

# Label each table with its model name: the file name minus the trailing
# ".csv". Anchoring the pattern with `$` removes only the extension and
# leaves any ".csv" occurring earlier in the name untouched.
names(score_tables) <- sub("\\.csv$", "", scores_list)

# Reduce each score table to one row by applying `stat_fn` to every column,
# tag that row with its model name, and stack the rows into a single table.
#
# tables:  named list of score tables (names are the model labels).
# stat_fn: function applied to each column, e.g. `mean` or `sd`.
summarise_scores <- function(tables, stat_fn) {
  tables %>%
    imap(function(score_table, model_name) {
      score_table %>%
        summarise(across(everything(), ~ stat_fn(.))) %>%
        mutate(model = model_name)
    }) %>%
    bind_rows()
}

# Per-model mean of every score column (one row per model).
scores_mean <- summarise_scores(score_tables, mean)

# Per-model standard deviation of every score column (one row per model).
scores_sd <- summarise_scores(score_tables, sd)

# All individual scores in long format: the score columns are stacked into
# a `metric` name column and the default `value` column, keyed by `model`.
scores_full <-
  score_tables %>%
  imap(function(score_table, model_name) {
    mutate(score_table, model = model_name)
  }) %>%
  bind_rows() %>%
  pivot_longer(-c(model), names_to = "metric")

# Summary table with one row per model: mean and standard deviation of the
# `accuracy` and `roc_auc_macro` columns side by side.
# The join key is stated explicitly; `model` is the only column the two
# inputs share after the selections, so naming it avoids relying on (and
# being notified about) the implicit common-column match.
scores <-
  scores_mean %>%
  rename(
    accuracy_mean = accuracy,
    roc_auc_macro_mean = roc_auc_macro
  ) %>%
  select(model, ends_with("mean")) %>%
  left_join(
    scores_sd %>%
      rename(
        accuracy_sd = accuracy,
        roc_auc_macro_sd = roc_auc_macro
      ) %>%
      select(model, ends_with("sd")),
    by = "model"
  )



# Figure 1: box plots of accuracy and macro ROC AUC for each model.
#
# The plot object is built before the graphics device is opened, so an error
# while preparing the data cannot leave a device open or a partial PNG file
# behind.
performance_plot <-
  scores_full %>%
  filter(
    # Keep only the two metrics shown in the figure.
    metric %in% c("roc_auc_macro", "accuracy"),
    # Leave out the models named exactly "ASM_other" and "ASM".
    model != "ASM_other" & model != "ASM"
  ) %>%
  # Shorten the labels: remove "ASM_", then turn any remaining "ASM" into
  # "Full".
  # NOTE(review): the model named exactly "ASM" is filtered out above, so the
  # "Full" replacement only affects names that still contain "ASM" after
  # "ASM_" is removed - confirm this is intended.
  mutate(model = gsub("ASM", "Full", gsub("ASM_", "", model))) %>%
  ggplot() +
  geom_boxplot(aes(x = model, y = value, color = metric)) +
  theme(axis.text.x = element_text(angle = 45, hjust = 1)) +
  lims(y = c(0, 1)) +
  labs(
    x = "Model",
    y = "Score",
    color = "Metric",
    title = "Performance of Classification"
  ) +
  scale_color_brewer(
    palette = "Dark2",
    labels = c("Area Under ROC", "Accuracy"),
    breaks = c("roc_auc_macro", "accuracy")
  )

# Render the finished plot to a 7 x 4 inch PNG at 300 dpi.
agg_png(
  "./Figures/Figure-1-ASM-Performance.png",
  width = 7,
  height = 4,
  units = "in",
  res = 300
)
print(performance_plot)
dev.off()
